{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 裁剪 YOLOv8n 以探索计算图分割"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relay.dataflow_pattern import rewrite\n",
    "from tvm_book.transforms.yolo import Dist2xywhSimplify\n",
    "import tvm\n",
    "\n",
    "with tvm.transform.PassContext(opt_level=3, disabled_pass={\"AlterOpLayout\"}):\n",
    "    run_mod = tvm.IRModule.from_expr(rewrite(Dist2xywhSimplify(), mod[\"main\"]))\n",
    "    lib = relay.build(run_mod, target=\"llvm\", params=params)\n",
    "\n",
    "func = lib[lib.libmod_name]\n",
    "module = tvm.contrib.graph_executor.GraphModule(func(tvm.cpu(0)))\n",
    "module.run(**{input_name: xs})\n",
    "num_outputs = module.get_num_outputs()\n",
    "run_float_outputs = [module.get_output(k).numpy() for k in range(num_outputs)]\n",
    "[\n",
    "    np.testing.assert_allclose(a, b, rtol=1e-07, atol=1e-3)\n",
    "    for a, b in zip(float_outputs, run_float_outputs)\n",
    "]\n",
    "results = postprocess(\n",
    "    [torch.from_numpy(o) for o in run_float_outputs], \n",
    "    xs, [origin_image], self.model.names, \n",
    "    input_path, conf_thres=0.25, iou_thres=0.45,\n",
    ")\n",
    "Image.fromarray(results[0].plot())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "from tvm.relay.analysis import extract_intermdeiate_expr\n",
    "from tvm_book.compiler.utils import merge_compiler\n",
    "\n",
    "run_mod = deepcopy(mod) \n",
    "# run_mod = extract_intermdeiate_expr(run_mod, 110)\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    run_mod[\"main\"] = rewrite(Dist2xywhSimplify(), run_mod[\"main\"])\n",
    "    run_mod = relay.quantize.prerequisite_optimize(run_mod, params)\n",
    "    run_mod = merge_compiler(run_mod, compiler_name=\"vta_special\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(run_mod[\"main\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "\n",
    "ENV = {\n",
    "    \"model_type\": \"onnx\",\n",
    "    \"input_name\": \"images\",\n",
    "    \"channel\": 3,\n",
    "    \"height\": 640, \n",
    "    \"width\": 640,\n",
    "    \"mode\": \"RGB\", # 输入图片格式\n",
    "    \"mean\": (0,),\n",
    "    \"std\": (255,)\n",
    "}\n",
    "\n",
    "def letterbox_image(im: Image, dst_width: int, dst_height: int):\n",
    "    '''使用填充保持纵横比缩放图像\n",
    "    \n",
    "    Args:\n",
    "        im: 原始 Image\n",
    "        dst_width: 目标宽度\n",
    "        dst_height: 目标高度\n",
    "    '''\n",
    "    iw, ih = im.size # 原始尺寸\n",
    "    scale = min(dst_width/iw, dst_height/ih)\n",
    "    nw = int(iw*scale)\n",
    "    nh = int(ih*scale)\n",
    "    im = im.resize((nw, nh), Image.BICUBIC)\n",
    "    new_image = Image.new('RGB', (dst_width, dst_height), (114, 114, 114))\n",
    "    new_image.paste(im, ((dst_width-nw)//2, (dst_height-nh)//2))\n",
    "    return new_image\n",
    "\n",
    "def preprocessing(path: str|None, **ENV: dict):\n",
    "    if not path:\n",
    "        im = np.random.randint(0, 256, size=(32, 32, 3), dtype=\"uint8\")\n",
    "        im = Image.fromarray(im) # 转为 Image 实例\n",
    "    else:\n",
    "        im = Image.open(path)\n",
    "    # im = im.resize((ENV[\"width\"], ENV[\"height\"]), Image.BICUBIC)\n",
    "    im = letterbox_image(im, ENV[\"width\"], ENV[\"height\"])\n",
    "    if ENV[\"mode\"] == \"L\": # 将灰度图转换为 HWC 布局\n",
    "        img = im.convert(\"L\")\n",
    "        img = np.expand_dims(img, axis=-1) # 转为 HWC\n",
    "    elif ENV[\"mode\"] == \"RGB\":\n",
    "        img = np.array(im.convert(\"RGB\")) # 转为 HWC 布局\n",
    "    elif ENV[\"mode\"] == \"BGR\":\n",
    "        img = np.array(im.convert(\"RGB\")) # 转为 HWC 布局\n",
    "        img = img[..., ::-1] # RGB 转 BGR\n",
    "    else:\n",
    "        raise TypeError(f'暂未支持数据布局 {ENV[\"mode\"]}')\n",
    "    image_np = np.expand_dims(img, 0) # 转换为 NHWC (uint8 数据)\n",
    "    # 预处理后的数据\n",
    "    data_inp = ((image_np - ENV[\"mean\"]) / ENV[\"std\"]).astype(np.float32)\n",
    "    data_inp = data_inp.transpose(0, 3, 1, 2)\n",
    "    return np.ascontiguousarray(image_np), np.ascontiguousarray(data_inp)\n",
    "\n",
    "def calibrateset(calibrate_num=2, data_dir=\"/media/pc/data/lxw/home/data/coco/train2017\"):\n",
    "    \"\"\"用于量化的校准数据集\"\"\"\n",
    "    for k, path in tqdm(enumerate(Path(data_dir).iterdir()), desc=\"Calibrate\", unit=\"batch\"):\n",
    "        if k >= calibrate_num:\n",
    "            break\n",
    "        yield {ENV[\"input_name\"]: preprocessing(path, **ENV)[1]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run_mod = deepcopy(mod)\n",
    "with tvm.transform.PassContext(opt_level=3, disabled_pass={\"AlterOpLayout\"}):\n",
    "    run_mod[\"main\"] = rewrite(Dist2xywhSimplify(), run_mod[\"main\"])\n",
    "    with relay.quantize.qconfig(\n",
    "        calibrate_mode=\"percentile\", weight_scale=\"max\"):\n",
    "        qmod = relay.quantize.quantize(run_mod, params, dataset=calibrateset())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qmod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relay.dataflow_pattern import rewrite\n",
    "from tvm_book.transforms.yolo import Dist2xywhSimplify\n",
    "import tvm\n",
    "\n",
    "with tvm.transform.PassContext(opt_level=3, disabled_pass={\"AlterOpLayout\"}):\n",
    "    lib = relay.build(qmod, target=\"llvm\", params=params)\n",
    "\n",
    "func = lib[lib.libmod_name]\n",
    "module = tvm.contrib.graph_executor.GraphModule(func(tvm.cpu(0)))\n",
    "module.run(**{input_name: xs})\n",
    "num_outputs = module.get_num_outputs()\n",
    "quant_outputs = [module.get_output(k).numpy() for k in range(num_outputs)]\n",
    "results = postprocess(\n",
    "    [torch.from_numpy(o) for o in quant_outputs], \n",
    "    xs, [origin_image], self.model.names, \n",
    "    input_path, conf_thres=0.25, iou_thres=0.45,\n",
    ")\n",
    "Image.fromarray(results[0].plot())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ww"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relay.analysis import _ffi_api\n",
    "\n",
    "output_map = _ffi_api.get_calibrate_output_map(run_mod)\n",
    "calibrate_mod = _ffi_api.get_calibrate_module(run_mod)\n",
    "calibrate_mod = relay.transform.Inline()(calibrate_mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ref_res = relay.build_module.create_executor(\"graph\", mod=calibrate_mod, device=tvm.cpu(0)).evaluate()(**{input_name: xs})\n",
    "\n",
    "calib_data = {}\n",
    "for gvar, indices in output_map.items():\n",
    "    offset = int(indices[0])\n",
    "    in_len = int(indices[1])\n",
    "    out_len = int(indices[2])\n",
    "    value = {\n",
    "        \"inputs\": ref_res[offset : offset + in_len],\n",
    "        \"outputs\": ref_res[offset + in_len : offset + in_len + out_len],\n",
    "    }\n",
    "    calib_data[gvar] = value\n",
    "func_map = {int(kk.name_hint.split(\"_\")[-1]): kk for kk in calib_data.keys()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "calib_data[func_map[len(func_map)-1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
